{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#--coding:utf-8--\n",
    "# Build a ResNeXt50-based 4-class classifier and print its architecture.\n",
    "\n",
    "import keras\n",
    "from keras.applications.resnet import ResNet50, preprocess_input\n",
    "from keras_applications.resnext import ResNeXt50\n",
    "from keras.layers import Dense, GlobalAveragePooling2D\n",
    "from keras.models import Model\n",
    "\n",
    "# The keras_applications constructors take the Keras submodules to build on\n",
    "# as keyword arguments, so `keras` itself has to be imported in this cell\n",
    "# for it to run on a fresh kernel.\n",
    "base_model = ResNeXt50(include_top=False,\n",
    "                       weights=None,\n",
    "                       input_shape=(224, 224, 3),\n",
    "                       backend=keras.backend,\n",
    "                       layers=keras.layers,\n",
    "                       models=keras.models,\n",
    "                       utils=keras.utils)\n",
    "\n",
    "# Classification head: global average pooling followed by a 4-way softmax.\n",
    "x = base_model.output\n",
    "x = GlobalAveragePooling2D()(x)\n",
    "predictions = Dense(4, activation='softmax')(x)\n",
    "model = Model(inputs=base_model.input, outputs=predictions)\n",
    "\n",
    "model.summary()\n",
    "print('the number of layers in this model:' + str(len(model.layers)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --coding:utf-8--\n",
    "import os\n",
    "import sys\n",
    "import glob\n",
    "import matplotlib.pyplot as plt\n",
    "import keras\n",
    "from keras import __version__\n",
    "from keras.applications.inception_resnet_v2 import InceptionResNetV2, preprocess_input\n",
    "from keras.models import Model\n",
    "from keras.layers import Dense, GlobalAveragePooling2D\n",
    "from keras.preprocessing.image import ImageDataGenerator\n",
    "from keras.optimizers import SGD\n",
    "from keras.callbacks import ModelCheckpoint\n",
    "from keras.callbacks import LearningRateScheduler, ModelCheckpoint, TensorBoard, EarlyStopping, ReduceLROnPlateau\n",
    "\n",
    "\n",
    "def get_nb_files(directory):\n",
    "    \"\"\"Count the files stored in the subdirectories of `directory`.\n",
    "\n",
    "    The tree is walked recursively and the entries of every subdirectory are\n",
    "    matched with the glob pattern ``*`` (so dot-files are skipped). Only\n",
    "    regular files are counted: nested directories are descended into but are\n",
    "    not themselves counted. Files placed directly in `directory` are ignored.\n",
    "\n",
    "    Args:\n",
    "        directory: Root folder to scan, e.g. a dataset split that holds one\n",
    "            sub-folder per class.\n",
    "\n",
    "    Returns:\n",
    "        The number of matching files, or 0 if `directory` does not exist.\n",
    "    \"\"\"\n",
    "    if not os.path.exists(directory):\n",
    "        return 0\n",
    "    cnt = 0\n",
    "    for root, dirs, _ in os.walk(directory):\n",
    "        for sub in dirs:\n",
    "            entries = glob.glob(os.path.join(root, sub, \"*\"))\n",
    "            # A nested folder is a glob match too, but it is not a sample.\n",
    "            cnt += sum(1 for entry in entries if os.path.isfile(entry))\n",
    "    return cnt\n",
    "\n",
    "\n",
    "# Data preparation\n",
    "print(\"导入数据，设置基础参数...\")\n",
    "IM_WIDTH, IM_HEIGHT = 100, 100  # image size fed to the network\n",
    "train_dir = '/kaggle/input/tomatolay/normal/train'  # training set path\n",
    "val_dir = '/kaggle/input/tomatolay/normal/val'  # validation set path\n",
    "nb_epoch = 20\n",
    "batch_size = 64\n",
    "\n",
    "nb_train_samples = get_nb_files(train_dir)  # number of training samples\n",
    "nb_classes = len(glob.glob(train_dir + \"/*\"))  # one class per entry in train_dir\n",
    "nb_val_samples = get_nb_files(val_dir)  # number of validation samples\n",
    "\n",
    "# Image augmentation (training data only)\n",
    "print(\"数据增强...\")\n",
    "train_datagen = ImageDataGenerator(preprocessing_function=preprocess_input,\n",
    "                                   rotation_range=30,\n",
    "                                   width_shift_range=0.2,\n",
    "                                   height_shift_range=0.2,\n",
    "                                   shear_range=0.2,\n",
    "                                   zoom_range=0.2,\n",
    "                                   horizontal_flip=True)\n",
    "# Validation/test images only receive the model's preprocessing. Random\n",
    "# augmentation here would make the monitored val_loss noisy and would make\n",
    "# evaluation results change from run to run.\n",
    "test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)\n",
    "\n",
    "# Training and validation generators\n",
    "print(\"准备训练数据与测试数据...\")\n",
    "# flow_from_directory expects target_size as (height, width).\n",
    "train_generator = train_datagen.flow_from_directory(\n",
    "    train_dir,\n",
    "    target_size=(IM_HEIGHT, IM_WIDTH),\n",
    "    batch_size=batch_size,\n",
    "    class_mode='categorical')\n",
    "\n",
    "validation_generator = test_datagen.flow_from_directory(\n",
    "    val_dir,\n",
    "    target_size=(IM_HEIGHT, IM_WIDTH),\n",
    "    batch_size=batch_size,\n",
    "    class_mode='categorical')\n",
    "\n",
    "# Mapping from class name to label index, and the class names as a list.\n",
    "class_dict = train_generator.class_indices\n",
    "li = list(class_dict.keys())\n",
    "\n",
    "\n",
    "# Attach a new classification head\n",
    "def add_new_last_layer(base_model, nb_classes):\n",
    "    \"\"\"Put a pooled softmax classifier on top of a headless base network.\n",
    "\n",
    "    Args:\n",
    "        base_model: Keras model whose output is fed to the new layers.\n",
    "        nb_classes: Number of units of the final softmax layer.\n",
    "\n",
    "    Returns:\n",
    "        A new Keras Model that shares the inputs of `base_model` and outputs\n",
    "        the softmax predictions.\n",
    "    \"\"\"\n",
    "    pooled = GlobalAveragePooling2D()(base_model.output)\n",
    "    softmax_head = Dense(nb_classes, activation='softmax')  # new softmax layer\n",
    "    return Model(inputs=base_model.inputs, outputs=softmax_head(pooled))\n",
    "\n",
    "\n",
    "# Build the model\n",
    "print(\"模型搭建中...\")\n",
    "# The input shape follows the generator image size instead of repeating the\n",
    "# literal, so changing IM_HEIGHT / IM_WIDTH keeps model and data in sync.\n",
    "base_model = InceptionResNetV2(weights='imagenet',\n",
    "                               input_shape=(IM_HEIGHT, IM_WIDTH, 3),\n",
    "                               include_top=False)\n",
    "model = add_new_last_layer(base_model, nb_classes)\n",
    "model.compile(optimizer=SGD(lr=0.001,\n",
    "                            momentum=0.9,\n",
    "                            decay=0.0001,\n",
    "                            nesterov=True),\n",
    "              loss='categorical_crossentropy',\n",
    "              metrics=['accuracy'])\n",
    "# Callbacks\n",
    "print(\"设置回调...\")\n",
    "# Divide the learning rate by 10 when val_loss has not improved for 10 epochs.\n",
    "learning_rate_reduction = ReduceLROnPlateau(monitor='val_loss',\n",
    "                                            factor=0.1,\n",
    "                                            patience=10,\n",
    "                                            verbose=0,\n",
    "                                            mode='auto',\n",
    "                                            min_delta=0.0001,\n",
    "                                            cooldown=0,\n",
    "                                            min_lr=0)\n",
    "# NOTE(review): the checkpoint is named after ResNeXt101 although the network\n",
    "# trained here is InceptionResNetV2; the name is kept in case other code\n",
    "# loads this file - confirm and rename.\n",
    "filepath = \"ResNeXt101.hdf5\"\n",
    "# Keep only the weights of the epoch with the lowest validation loss.\n",
    "checkpoint = ModelCheckpoint(filepath,\n",
    "                             monitor='val_loss',\n",
    "                             verbose=1,\n",
    "                             save_best_only=True,\n",
    "                             mode='min')\n",
    "callbacks_list = [checkpoint, learning_rate_reduction]\n",
    "\n",
    "# Training\n",
    "print(\"开始训练\")\n",
    "# len(generator) is the number of batches needed to see every image once.\n",
    "# Unlike `samples // batch_size` it keeps the last partial batch and never\n",
    "# evaluates to 0 steps when a split holds fewer images than one batch.\n",
    "history_ft = model.fit_generator(train_generator,\n",
    "                                 steps_per_epoch=len(train_generator),\n",
    "                                 epochs=nb_epoch,\n",
    "                                 validation_data=validation_generator,\n",
    "                                 validation_steps=len(validation_generator),\n",
    "                                 callbacks=callbacks_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the trained model on the held-out test set.\n",
    "test_dir = \"/kaggle/input/tomatolay/normal/test\"\n",
    "test_generator_samples = get_nb_files(test_dir)\n",
    "# Dedicated generator that applies only the model's preprocessing, so the\n",
    "# reported test score does not depend on random augmentation.\n",
    "eval_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)\n",
    "# flow_from_directory expects target_size as (height, width).\n",
    "test_generator = eval_datagen.flow_from_directory(\n",
    "    test_dir,\n",
    "    target_size=(IM_HEIGHT, IM_WIDTH),\n",
    "    batch_size=batch_size,\n",
    "    class_mode='categorical',\n",
    "    shuffle=False)\n",
    "# `steps` must be a whole number of batches; len(generator) covers every\n",
    "# test image exactly once, including the last partial batch.\n",
    "test_loss, test_acc = model.evaluate_generator(test_generator,\n",
    "                                               steps=len(test_generator))\n",
    "# The accuracy metric is a fraction in [0, 1]; scale it to match the % sign.\n",
    "print('test acc: %.4f%%' % (test_acc * 100))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tf-gpu",
   "language": "python",
   "name": "tf-gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
